from context_fid import Context_FID
from cross_correlation import CrossCorrelLoss
from discriminative_metric_torch import discriminative_score_metrics
from predictive_metric_torch import predictive_score_metrics
import numpy as np
import os
from metric_utils import visualization
import torch
import argparse


def parse_args():
    """Parse the command-line options of the evaluation script.

    Options:
        --name     dataset name used to locate files under ``OUTPUT/<name>``
                   (default ``'ETTh1'``).
        --seq_len  sequence length embedded in the saved sample file name
                   (default ``24``).

    Returns:
        argparse.Namespace with attributes ``name`` and ``seq_len``.
    """
    parser = argparse.ArgumentParser(description='Performance Evaluation Script')
    # (flag, converter, default) for each supported option.
    option_specs = (
        ('--name', str, 'ETTh1'),
        ('--seq_len', int, 24),
    )
    for flag, converter, fallback in option_specs:
        parser.add_argument(flag, type=converter, default=fallback)
    return parser.parse_args()

def _resolve_truth_path(output_dir, data_name, seq_len):
    """Return the path of the saved real (training) samples.

    Most datasets save them as ``{name}_norm_truth_{seq_len}_train.npy``; the
    sines dataset names the file ``ground_truth`` instead of ``norm_truth``.
    The ``norm_truth`` file is preferred and ``ground_truth`` is used only when
    the former does not exist, so the script no longer needs manual editing.
    """
    sample_dir = os.path.join(output_dir, 'samples')
    norm_path = os.path.join(sample_dir, f'{data_name}_norm_truth_{seq_len}_train.npy')
    if os.path.exists(norm_path):
        return norm_path
    ground_path = os.path.join(sample_dir, f'{data_name}_ground_truth_{seq_len}_train.npy')
    if os.path.exists(ground_path):
        return ground_path
    # Neither file exists: return the primary path so np.load raises a
    # FileNotFoundError that names the expected file.
    return norm_path


def _repeat_metric(label, metric_fn, iter_num):
    """Evaluate ``metric_fn`` ``iter_num`` times and summarise the scores.

    Args:
        label: Name used in the progress messages.
        metric_fn: Zero-argument callable returning one score per call.
        iter_num: Number of repetitions.

    Returns:
        (mean, std) of the collected scores, as computed by np.mean / np.std.
    """
    print(f'Start {label}')
    scores = []
    for _ in range(iter_num):
        score = metric_fn()
        scores.append(score)
        print(f'{label} score: {score}')
    mean, std = np.mean(scores), np.std(scores)
    print(f'Mean {label} score: {mean} ± {std}')
    return mean, std


def _cross_correlation_once(x_real, x_fake):
    """Return one bootstrap estimate of the cross-correlation loss.

    The same resampled indices are applied to both tensors, so both must have
    the same number of samples along dim 0.
    """
    idx = np.random.randint(0, x_real.shape[0], x_real.shape[0])
    return CrossCorrelLoss(x_real[idx, :, :]).compute(x_fake[idx, :, :])


def main():
    """Score generated samples against real ones, save metrics and plots."""
    args = parse_args()

    data_name = args.name
    seq_len = args.seq_len

    output_dir = f'OUTPUT/{data_name}'
    ori_path = _resolve_truth_path(output_dir, data_name, seq_len)
    gen_path = os.path.join(output_dir, f'ddpm_fake_{data_name}.npy')

    ori_data = np.load(ori_path)
    gen_data = np.load(gen_path)
    # Compare equal numbers of samples. Truncating only the generated set is
    # not enough: when it holds fewer samples than the real set, the shared
    # bootstrap indices in the cross-correlation step would run past its end.
    num_samples = min(len(ori_data), len(gen_data))
    if len(ori_data) > num_samples:
        print(f'Warning: only {len(gen_data)} generated samples; '
              f'truncating original data from {len(ori_data)} to match.')
    ori_data = ori_data[:num_samples, :, :]
    gen_data = gen_data[:num_samples, :, :]
    print(f'Shape of original data: {ori_data.shape}, generated data: {gen_data.shape}')
    print(f'Min and max of original data: {np.min(ori_data)}, {np.max(ori_data)}, generated data: {np.min(gen_data)}, {np.max(gen_data)}')

    metric_results = dict()
    iter_num = 10

    fid_mean, fid_std = _repeat_metric(
        'Context_FID', lambda: Context_FID(ori_data, gen_data), iter_num)
    metric_results['Context_FID_mean'] = fid_mean
    metric_results['Context_FID_std'] = fid_std

    # torch.from_numpy shares memory with the arrays; no copy is made.
    x_real = torch.from_numpy(ori_data)
    x_fake = torch.from_numpy(gen_data)
    cross_mean, cross_std = _repeat_metric(
        'Cross_correlation_score',
        lambda: _cross_correlation_once(x_real, x_fake),
        iter_num)
    metric_results['Cross_correlation_score'] = cross_mean
    metric_results['Cross_correlation_score_std'] = cross_std
    del x_real, x_fake

    disc_mean, disc_std = _repeat_metric(
        'discriminative_score_metrics',
        lambda: discriminative_score_metrics(ori_data, gen_data),
        iter_num)
    metric_results['discriminative'] = disc_mean
    metric_results['discriminative_std'] = disc_std

    pred_mean, pred_std = _repeat_metric(
        'predictive_score_metrics',
        lambda: predictive_score_metrics(ori_data, gen_data),
        iter_num)
    metric_results['predictive'] = pred_mean
    metric_results['predictive_std'] = pred_std

    print(metric_results)

    os.makedirs(output_dir, exist_ok=True)
    output_file = os.path.join(output_dir, 'metric_results.txt')
    with open(output_file, 'w', encoding='utf-8') as f:
        for key, value in metric_results.items():
            f.write(f'{key}: {value}\n')

    visualization(ori_data, gen_data, 'kernel', compare=None, save_dir=output_dir)
    # PCA / t-SNE plots use at most the first 3000 samples (presumably to keep
    # the embeddings tractable).
    ori_subset = ori_data[:3000, :, :]
    gen_subset = gen_data[:3000, :, :]
    visualization(ori_subset, gen_subset, 'pca', compare=None, save_dir=output_dir)
    visualization(ori_subset, gen_subset, 'tsne', compare=None, save_dir=output_dir)


if __name__ == '__main__':
    main()

